Understanding the worker, part 2

  1. Processing of documents: The process_document function is responsible for processing the PDF documents. It uses PyPDFLoader to load the document, splits it into chunks with RecursiveCharacterTextSplitter, and then creates a Chroma vector store from the chunks using the language model embeddings. That vector store is used to create a retriever interface, which in turn is used to build the RetrievalQA chain stored in conversation_retrieval_chain.

    • Document loading: The PDF document is loaded using the PyPDFLoader class, which takes the path of the document as an argument. (Todo exercise: assign PyPDFLoader(…) to loader)

    • Document splitting: The loaded document is split into chunks using the RecursiveCharacterTextSplitter class. The chunk_size and chunk_overlap can be specified. (Todo exercise: assign RecursiveCharacterTextSplitter(…) to text_splitter)

    • Vector store creation: A vector store, which is a kind of index, is created from the document chunks using the language model embeddings. This is done using the Chroma class.

    • Retrieval system setup: A retrieval system is set up by turning the vector store into a retriever, which is then passed to RetrievalQA.from_chain_type to build the chain (stored in conversation_retrieval_chain) used to answer questions based on the document content.
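To build intuition for the splitting step, here is a simplified, hypothetical stand-in for what RecursiveCharacterTextSplitter does at its core: cutting the text into fixed-size windows that share an overlap. (The real LangChain class is smarter; it first tries to split on separators such as paragraphs and sentences. `naive_split` is an illustrative helper, not part of the lab code.)

```python
# Simplified, hypothetical stand-in for RecursiveCharacterTextSplitter:
# fixed-size character windows that share `chunk_overlap` characters.
def naive_split(text, chunk_size=1024, chunk_overlap=64):
    step = chunk_size - chunk_overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]

chunks = naive_split("abcdefghij" * 10, chunk_size=40, chunk_overlap=10)
print(len(chunks))                        # → 4
print(chunks[0][-10:] == chunks[1][:10])  # → True (adjacent chunks overlap)
```

The overlap ensures that a sentence falling on a chunk boundary still appears whole in at least one chunk, which improves retrieval quality.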

To do: complete the blank parts

  # Function to process a PDF document
  def process_document(document_path):
      global conversation_retrieval_chain
      logger.info("Loading document from path: %s", document_path)
      # Load the document
      loader = # ---> use PyPDFLoader and document_path from the function input parameter <---
      documents = loader.load()
      logger.debug("Loaded %d document(s)", len(documents))
      # Split the document into chunks: set chunk_size=1024 and chunk_overlap=64, and assign it to the variable text_splitter
      text_splitter = # ---> use RecursiveCharacterTextSplitter and specify the input parameters <---
      texts = text_splitter.split_documents(documents)
      logger.debug("Document split into %d text chunks", len(texts))
      # Create an embeddings database using Chroma from the split text chunks.
      logger.info("Initializing Chroma vector store from documents...")
      db = Chroma.from_documents(texts, embedding=embeddings)
      logger.debug("Chroma vector store initialized.")
      # Optional: Log available collections if accessible (this may be internal API)
      try:
          collections = db._client.list_collections()  # _client is internal; adjust if needed
          logger.debug("Available collections in Chroma: %s", collections)
      except Exception as e:
          logger.warning("Could not retrieve collections from Chroma: %s", e)
      # Build the QA chain, which utilizes the LLM and retriever for answering questions.
      conversation_retrieval_chain = RetrievalQA.from_chain_type(
          llm=llm_hub,
          chain_type="stuff",
          retriever=db.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25}),
          return_source_documents=False,
          input_key="question"
          # chain_type_kwargs={"prompt": prompt}  # if you are using a prompt template, uncomment this part
      )
      logger.info("RetrievalQA chain created successfully.")
Solution:
  # Function to process a PDF document
  def process_document(document_path):
      global conversation_retrieval_chain
      logger.info("Loading document from path: %s", document_path)
      # Load the document
      loader = PyPDFLoader(document_path)
      documents = loader.load()
      logger.debug("Loaded %d document(s)", len(documents))
      # Split the document into chunks
      text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
      texts = text_splitter.split_documents(documents)
      logger.debug("Document split into %d text chunks", len(texts))
      # Create an embeddings database using Chroma from the split text chunks.
      logger.info("Initializing Chroma vector store from documents...")
      db = Chroma.from_documents(texts, embedding=embeddings)
      logger.debug("Chroma vector store initialized.")
      # Optional: Log available collections if accessible (this may be internal API)
      try:
          collections = db._client.list_collections()  # _client is internal; adjust if needed
          logger.debug("Available collections in Chroma: %s", collections)
      except Exception as e:
          logger.warning("Could not retrieve collections from Chroma: %s", e)
      # Build the QA chain, which utilizes the LLM and retriever for answering questions.
      conversation_retrieval_chain = RetrievalQA.from_chain_type(
          llm=llm_hub,
          chain_type="stuff",
          retriever=db.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25}),
          return_source_documents=False,
          input_key="question"
          # chain_type_kwargs={"prompt": prompt}  # if you are using a prompt template, uncomment this part
      )
      logger.info("RetrievalQA chain created successfully.")
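The retriever in process_document uses search_type="mmr" with k=6 and lambda_mult=0.25. As a rough sketch of the idea only (this is not Chroma's actual implementation), maximal marginal relevance scores each candidate chunk by a weighted trade-off between relevance to the query and redundancy with chunks already selected; a low lambda_mult such as 0.25 favors diverse results:

```python
import math

# Toy illustration of maximal marginal relevance (MMR) -- NOT Chroma's code.
# score(doc) = lambda_mult * sim(query, doc)
#            - (1 - lambda_mult) * max similarity to already-selected docs
def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def mmr(query, docs, k, lambda_mult):
    selected, candidates = [], list(range(len(docs)))
    while candidates and len(selected) < k:
        def score(i):
            relevance = cosine(query, docs[i])
            redundancy = max((cosine(docs[i], docs[j]) for j in selected), default=0.0)
            return lambda_mult * relevance - (1 - lambda_mult) * redundancy
        best = max(candidates, key=score)
        selected.append(best)
        candidates.remove(best)
    return selected

docs = [[1.0, 0.0], [0.99, 0.1], [0.0, 1.0]]
print(mmr([1.0, 0.0], docs, k=2, lambda_mult=0.25))  # → [0, 2]
```

Note that the second pick is the orthogonal document rather than the near-duplicate of the first: with lambda_mult=0.25, the redundancy penalty outweighs the relevance gain.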
  2. Prompt processing (process_prompt function): This function handles a user's prompt or question, retrieves a response based on the contents of the previously processed PDF document, and maintains a chat history. It does the following:

    • Passes the prompt and the chat history to conversation_retrieval_chain (the RetrievalQA chain built in process_document), the primary tool used to query the language model and get an answer based on the processed PDF document's contents.
    • Appends the prompt and the bot's response to the chat history.
    • Returns the bot's response.
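The bookkeeping described above can be sketched stand-alone by stubbing out the chain call (fake_chain_invoke below is a hypothetical stand-in for conversation_retrieval_chain.invoke, since the real chain needs a loaded document and an LLM):

```python
# Sketch of process_prompt's chat-history bookkeeping, with the chain call
# replaced by a stub so the example runs without LangChain or a model.
chat_history = []

def fake_chain_invoke(payload):
    # Stand-in for conversation_retrieval_chain.invoke({...})
    return {"result": f"Answer to: {payload['question']}"}

def process_prompt(prompt):
    output = fake_chain_invoke({"question": prompt, "chat_history": chat_history})
    answer = output["result"]
    chat_history.append((prompt, answer))  # store (user prompt, bot answer) pairs
    return answer

process_prompt("What is the document about?")
print(len(chat_history))  # → 1
```

Each call appends one (prompt, answer) tuple, so the history grows by exactly one exchange per question.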

Here's a skeleton of the process_prompt function for the exercise:

  # Function to process a user prompt
  def process_prompt(prompt):
      global conversation_retrieval_chain
      global chat_history
      logger.info("Processing prompt: %s", prompt)
      # Query the model using the new .invoke() method
      output = conversation_retrieval_chain.invoke({"question": prompt, "chat_history": chat_history})
      answer = output["result"]
      logger.debug("Model response: %s", answer)
      # Update the chat history
      # TODO: Append the prompt and the bot's response to the chat history using chat_history.append, passing `prompt` and `answer` as arguments
      # --> write your code here <--
      logger.debug("Chat history updated. Total exchanges: %d", len(chat_history))
      # Return the model's response
      return answer
Solution:
  # Function to process a user prompt
  def process_prompt(prompt):
      global conversation_retrieval_chain
      global chat_history
      logger.info("Processing prompt: %s", prompt)
      # Query the model using the new .invoke() method
      output = conversation_retrieval_chain.invoke({"question": prompt, "chat_history": chat_history})
      answer = output["result"]
      logger.debug("Model response: %s", answer)
      # Update the chat history
      chat_history.append((prompt, answer))
      logger.debug("Chat history updated. Total exchanges: %d", len(chat_history))
      # Return the model's response
      return answer
  3. Global variables:

    • llm_hub and embeddings store the language model and its embeddings, while conversation_retrieval_chain and chat_history store the retrieval chain and the chat history. The global statement is used inside the functions init_llm, process_document, and process_prompt to indicate that llm_hub, embeddings, conversation_retrieval_chain, and chat_history are global variables. This means that when these variables are modified inside these functions, the changes persist outside the functions as well, affecting the global state of the program.
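A minimal demonstration of why the global statement matters: without it, assignment inside a function creates a new local variable instead of rebinding the module-level name.

```python
# Minimal sketch of the `global` pattern used in worker.py.
llm_hub = None

def init_llm():
    global llm_hub      # rebind the module-level name, not a function-local one
    llm_hub = "initialized"

init_llm()
print(llm_hub)  # → initialized
```

Had init_llm omitted the global line, llm_hub at module level would still be None after the call.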

Here is the complete worker.py. The final code can be found in Worker_completed.py as well.

  import os
  import torch
  import logging

  # Configure logging
  logging.basicConfig(level=logging.DEBUG)
  logger = logging.getLogger(__name__)

  from langchain_core.prompts import PromptTemplate  # Updated import per deprecation notice
  from langchain.chains import RetrievalQA
  from langchain_community.embeddings import HuggingFaceInstructEmbeddings  # New import path
  from langchain_community.document_loaders import PyPDFLoader  # New import path
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain_community.vectorstores import Chroma  # New import path
  from langchain_ibm import WatsonxLLM

  # Check for GPU availability and set the appropriate device for computation.
  DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

  # Global variables
  conversation_retrieval_chain = None
  chat_history = []
  llm_hub = None
  embeddings = None

  # Function to initialize the language model and its embeddings
  def init_llm():
      global llm_hub, embeddings
      logger.info("Initializing WatsonxLLM and embeddings...")
      # Llama model configuration
      MODEL_ID = "meta-llama/llama-3-3-70b-instruct"
      WATSONX_URL = "https://us-south.ml.cloud.ibm.com"
      PROJECT_ID = "skills-network"
      # Use the same parameters as before:
      # MAX_NEW_TOKENS: 256, TEMPERATURE: 0.1
      model_parameters = {
          # "decoding_method": "greedy",
          "max_new_tokens": 256,
          "temperature": 0.1,
      }
      # Initialize the Llama LLM using the updated WatsonxLLM API
      llm_hub = WatsonxLLM(
          model_id=MODEL_ID,
          url=WATSONX_URL,
          project_id=PROJECT_ID,
          params=model_parameters
      )
      logger.debug("WatsonxLLM initialized: %s", llm_hub)
      # Initialize embeddings using a pre-trained model to represent the text data.
      ### --> If you are using the HuggingFace API:
      # Set the HUGGINGFACEHUB_API_TOKEN environment variable, initialize the desired model, and load it into the HuggingFaceHub.
      # Don't forget to remove the WatsonX llm_hub above.
      # os.environ["HUGGINGFACEHUB_API_TOKEN"] = "YOUR API KEY"
      # model_id = "tiiuae/falcon-7b-instruct"
      # llm_hub = HuggingFaceHub(repo_id=model_id, model_kwargs={"temperature": 0.1, "max_new_tokens": 600, "max_length": 600})
      embeddings = HuggingFaceInstructEmbeddings(
          model_name="sentence-transformers/all-MiniLM-L6-v2",
          model_kwargs={"device": DEVICE}
      )
      logger.debug("Embeddings initialized with model device: %s", DEVICE)

  # Function to process a PDF document
  def process_document(document_path):
      global conversation_retrieval_chain
      logger.info("Loading document from path: %s", document_path)
      # Load the document
      loader = PyPDFLoader(document_path)
      documents = loader.load()
      logger.debug("Loaded %d document(s)", len(documents))
      # Split the document into chunks
      text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
      texts = text_splitter.split_documents(documents)
      logger.debug("Document split into %d text chunks", len(texts))
      # Create an embeddings database using Chroma from the split text chunks.
      logger.info("Initializing Chroma vector store from documents...")
      db = Chroma.from_documents(texts, embedding=embeddings)
      logger.debug("Chroma vector store initialized.")
      # Optional: Log available collections if accessible (this may be internal API)
      try:
          collections = db._client.list_collections()  # _client is internal; adjust if needed
          logger.debug("Available collections in Chroma: %s", collections)
      except Exception as e:
          logger.warning("Could not retrieve collections from Chroma: %s", e)
      # Build the QA chain, which utilizes the LLM and retriever for answering questions.
      conversation_retrieval_chain = RetrievalQA.from_chain_type(
          llm=llm_hub,
          chain_type="stuff",
          retriever=db.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25}),
          return_source_documents=False,
          input_key="question"
          # chain_type_kwargs={"prompt": prompt}  # if you are using a prompt template, uncomment this part
      )
      logger.info("RetrievalQA chain created successfully.")

  # Function to process a user prompt
  def process_prompt(prompt):
      global conversation_retrieval_chain
      global chat_history
      logger.info("Processing prompt: %s", prompt)
      # Query the model using the new .invoke() method
      output = conversation_retrieval_chain.invoke({"question": prompt, "chat_history": chat_history})
      answer = output["result"]
      logger.debug("Model response: %s", answer)
      # Update the chat history
      chat_history.append((prompt, answer))
      logger.debug("Chat history updated. Total exchanges: %d", len(chat_history))
      # Return the model's response
      return answer

  # Initialize the language model
  init_llm()
  logger.info("LLM and embeddings initialization complete.")